from SegAgent import *
import os

def _output_stem(image_path):
    """Return the file name of ``image_path`` without directory or extension.

    ``os.path.splitext`` is used rather than slicing off a fixed number of
    characters, so extensions of any length (``.jpeg``, ``.webp``) and paths
    without an extension are handled correctly.
    """
    return os.path.splitext(os.path.basename(image_path))[0]


def _save_mask_outputs(mask, image_path, output_dir="output"):
    """Save ``mask`` as a grayscale PNG and as a red overlay on the image.

    Two files are written into ``output_dir`` (created if missing):
    ``<stem>_mask.png`` and ``<stem>_overlaid_mask.jpg``, where ``<stem>`` is
    the image file name without its extension.

    Args:
        mask: Array returned by ``seg_agent_qwenvl``.
            NOTE(review): assumed to be a 2-D boolean/0-1 array with the same
            height and width as the image -- confirm against SegAgent.
        image_path: Path of the source image; re-read here with OpenCV.
        output_dir: Directory that receives both output files.
    """
    stem = _output_stem(image_path)
    # The directory is not guaranteed to exist on a fresh checkout.
    os.makedirs(output_dir, exist_ok=True)

    # Scale to the 8-bit range: True -> 255, False -> 0.
    mask_uint8 = (mask * 255).astype(np.uint8)

    # Convert to a grayscale ('L') image and save it.
    mask_path = f"{output_dir}/{stem}_mask.png"
    Image.fromarray(mask_uint8, mode='L').save(mask_path)
    print(f'mask has been saved to {mask_path}')

    image = cv2.imread(image_path)  # BGR channel order
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        print(f'could not read image {image_path}; overlay was not created')
        return

    # Solid red image (BGR order) with the same size as the original.
    color_mask = np.zeros_like(image, dtype=np.uint8)
    color_mask[:] = (0, 0, 255)

    # Keep the color only where the mask is set.
    color_mask = cv2.bitwise_and(color_mask, color_mask, mask=mask_uint8)

    # Blend the colored mask onto the original; 0.6 is the weight of the color.
    output = cv2.addWeighted(image, 1.0, color_mask, 0.6, 0)

    overlay_path = f"{output_dir}/{stem}_overlaid_mask.jpg"
    # cv2.imwrite reports failure through its boolean return value.
    if cv2.imwrite(overlay_path, output):
        print(f'the result has been saved to {overlay_path}')
    else:
        print(f'failed to write the result to {overlay_path}')


if __name__ == '__main__':
    # Load the vision-language model, its processor, and the SAM2 predictor
    # once, then reuse them for every request in the interactive loop.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
    segmentation_model = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

    while True:
        try:
            image_path = input("please input the image path:")
            target_object = input("please input the target object you want to segment:")
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl-C at the prompt: leave the loop cleanly
            # instead of terminating with a traceback.
            print()
            break

        mask = seg_agent_qwenvl(model, processor, segmentation_model, image_path, target_object)
        _save_mask_outputs(mask, image_path)